#------------------------------------------------------------------------------
# plot trends over time
# Figure 1. Pain diagnosis
# Figure S1. Enrollment 
# Figure 2. Trends among pain patients
#==============================================================================

load_library = c('bit64','data.table','fst','future.apply','stringr','logger','vroom')
invisible(lapply(load_library, function(x) library(x, character.only=TRUE, quietly= TRUE)))

library(ggplot2)
library(ggpubr)

# Root of the replication package. Defaults to the original author's local
# path; set the OP_DATAVERSE_DIR environment variable to run elsewhere.
# An unset or empty variable falls back to the default.
bucket = Sys.getenv('OP_DATAVERSE_DIR')
if (!nzchar(bucket)) {
	bucket = '/Users/bk/Dropbox/project/OP/- posted/2021-06-08_covid_substitution_JAMA/dataverse'
}
data_path = file.path(bucket, 'data')

# Weekly, division-level aggregates read below with fst::read_fst
infile = file.path(data_path,'processed_data','ses_division_weekly_covid_0101_0930.fst')

# ----------
# Build the message with paste(): passing `infile` as a second positional
# argument is either concatenated without a separator (glue formatter) or
# silently dropped with a warning (sprintf formatter).
logger::log_info(paste('now loading data ...', infile))

dt = read_fst(infile, as.data.table = TRUE)

# `date` arrives as a '<year>W<week>' string (presumably e.g. '2020W05' —
# TODO confirm against the .fst file). Splitting on the literal 'W' yields the
# year and week parts. Neither part can still contain a 'W', so the week only
# needs a numeric conversion.
dt[, c('year','week') := tstrsplit(date, 'W',fixed=TRUE)]
dt[, week := as.numeric(week)]
# Convert (year, week) to a calendar date.
# - %U is the week of the year with Sunday as the first day of the week.
# - %u = 1 picks the Monday of that week.
# NOTE(review): if the source weeks follow ISO-8601 numbering, %U may shift
# some weeks by one — confirm the upstream week convention.
dt[, date := as.Date(paste(year, week, 1, sep="-"), "%Y-%U-%u")]
dt[, month := month(date)]

# Coerce the per-condition opioid aggregates (quantity, days supplied, MME;
# both the sum_sum_* and mean_sum_* variants) to plain doubles.
# They may be stored as integer64 (bit64 is loaded above) — assumption, TODO
# confirm against the .fst schema.
# Converting in place with `:=` avoids the full-table copy that
# `dt[[col]] = ...` triggers on a data.table for every column.
pre_condition_list = c('all','backpain','pain','ord')
opioid_value_cols = c(
	'sum_sum_opioid_quantity', 'sum_sum_opioid_days', 'sum_sum_opioid_mme',
	'mean_sum_opioid_quantity', 'mean_sum_opioid_days', 'mean_sum_opioid_mme'
)
# Every '<condition>_<value column>' combination, e.g. 'pain_sum_sum_opioid_mme'
numeric_cols = as.vector(outer(pre_condition_list, opioid_value_cols, paste, sep = '_'))
dt[, (numeric_cols) := lapply(.SD, as.numeric), .SDcols = numeric_cols]

# Drop week 1 of every year, keeping only rows whose parsed week number
# exceeds 1. data.table treats an NA condition in `i` as FALSE, so rows whose
# week failed to parse are removed as well.
dt = dt[week > 1]

#--------------
# 1. figure S1. enrollment data
#==============

# Weekly enrollment total: sum N over all rows that share a week-start date.
sum_target_var = c('N')
sum_dt = dt[, lapply(.SD, sum, na.rm = TRUE), .SDcols = sum_target_var, by = 'date']
sum_dt[, date := as.Date(date)]

# Do both updates in one call, so both are computed from the original dates:
# - `year` records the calendar year as a 2019/2020 factor (used for colour).
# - `date` moves 2019 dates onto the 2020 calendar by rewriting the year in
#   the date string, so both years overlay on one Jan-Sep x-axis.
sum_dt[, c('year', 'date') := list(
	factor(year(date), levels = c(2019, 2020)),
	as.Date(gsub('2019', '2020', date))
)]

# Enrollment in millions, one line per year, through the end of September.
p = ggplot(
	sum_dt[date <= as.Date('2020-09-30')],
	aes(x = date, y = N / 1e6, group = year, color = year)
) +
	geom_line() +
	geom_point() +
	scale_x_date(date_breaks = '1 month', date_labels = '%b') +
	scale_color_manual(name = 'year', values = c('2019' = 'black', '2020' = 'red')) +
	ylim(0, 20) +
	theme_bw() +
	theme(legend.position = 'top') +
	labs(x = NULL, y = 'Total Number of Enrolled Patients (in million)')

ggsave(file.path(data_path, 'figure_s1_N_enrolled_trends.eps'), p, width = 6, height = 4)
